library("dsl")
# Script-level seed. NOTE: single_simulation() calls set.seed() itself, so this
# only governs random draws made outside it (e.g. load_data(dev = TRUE)).
set.seed(634)

# Load data
load_data <- function(dev = FALSE, data_dir = "data") {
  # Load the protest annotation tables and keep only the protest images.
  #
  # Args:
  #   dev:      if TRUE, return a random ~10% subset of rows (uses the RNG).
  #   data_dir: directory containing annot_test.txt and annot_train.txt
  #             (defaults to "data", the previously hard-coded location).
  # Returns:
  #   A data frame with the rows of both tables where protest == 1, with the
  #   annotation columns coerced to numeric.
  data_test <- read.table(file.path(data_dir, "annot_test.txt"), header = TRUE)
  data_train <- read.table(file.path(data_dir, "annot_train.txt"), header = TRUE)

  # Stack test and train, keeping only rows where protest == 1. `which()`
  # drops rows whose `protest` is NA; plain logical indexing would turn each
  # of them into an all-NA row.
  src_data <- rbind(
    data_test[which(data_test$protest == 1), , drop = FALSE],
    data_train[which(data_train$protest == 1), , drop = FALSE]
  )

  # Annotation columns may be read as character (e.g. when non-protest rows
  # hold placeholders), so coerce them to numeric after filtering.
  numeric_columns <- c("violence", "sign", "photo", "fire", "police",
                       "children", "group_20", "flag", "night", "shouting")
  src_data[, numeric_columns] <- lapply(src_data[, numeric_columns], as.numeric)

  # For development use a small random subset of the data.
  if (dev) {
    n_keep <- round(nrow(src_data) / 10)
    src_data <- src_data[sample.int(nrow(src_data), size = n_keep), , drop = FALSE]
  }

  src_data
}

# Error functions
## Binary error: can specify accuracy and "balance"
binary_error <- function(arr, acc = 1.0, bal = 0.5) {
  # Corrupt a 0/1 vector so that roughly a fraction `acc` of it stays correct.
  #
  # Args:
  #   arr: vector of 0/1 labels (NAs are left untouched).
  #   acc: target accuracy; round(length(arr) * (1 - acc)) labels are flipped.
  #   bal: share of the errors drawn from true zeros (0 -> 1); the remaining
  #        1 - bal share is drawn from true ones (1 -> 0). Each share is capped
  #        by the number of labels of that class.
  # Returns:
  #   `arr` with the selected positions flipped.
  n_errors <- round(length(arr) * (1 - acc))

  # Candidate positions come from the ORIGINAL labels, before any flipping,
  # so a 0 turned into 1 can never be turned back in the second step (which
  # would silently lower the realised error rate and distort the balance).
  idx_0 <- which(arr == 0)
  idx_1 <- which(arr == 1)

  n_errors_0 <- min(round(n_errors * bal), length(idx_0))
  n_errors_1 <- min(round(n_errors * (1 - bal)), length(idx_1))

  # Index through sample.int(): sample(x, ...) treats a length-one numeric `x`
  # as 1:x, which would flip the wrong position when only one candidate exists.
  flip_0 <- idx_0[sample.int(length(idx_0), size = n_errors_0)]
  flip_1 <- idx_1[sample.int(length(idx_1), size = n_errors_1)]

  arr[flip_0] <- 1
  arr[flip_1] <- 0
  arr
}

random_binary_error <- function(arr, acc=1.0){
  # Flip a uniformly random subset of labels, regardless of their true class.
  #
  # Args:
  #   arr: vector of 0/1 labels.
  #   acc: target accuracy; round(length(arr) * (1 - acc)) positions are chosen
  #        without replacement.
  # Returns:
  #   `arr` where each chosen entry equal to 1 becomes 0 and every other chosen
  #   entry becomes 1.
  flip_count <- round(length(arr) * (1 - acc))
  flip_at <- sample.int(length(arr), size = flip_count)
  arr[flip_at] <- ifelse(arr[flip_at] == 1, 0, 1)
  arr
}

## Continuous error: can specify bias
continuous_error <- function(arr, bias = NA, sd = NA, scale = 0.1) {
  # Add normally distributed noise to a numeric vector.
  #
  # Args:
  #   arr:   numeric vector.
  #   bias:  mean of the added noise; NA (the default) means mean(arr) * scale.
  #   sd:    standard deviation of the noise; NA (the default) means
  #          sd(arr) * scale.
  #   scale: multiplier used when bias/sd are derived from the data.
  # Returns:
  #   arr + rnorm(length(arr), mean = bias, sd = sd), returned visibly.
  # (Previously bias/sd had no defaults, so the NA branches were reachable only
  # by passing NA explicitly, and the result was returned invisibly.)
  if (is.na(bias)) {
    bias <- mean(arr) * scale
  }
  if (is.na(sd)) {
    # stats:: makes explicit that this is the function, not the `sd` argument.
    sd <- stats::sd(arr) * scale
  }
  arr + stats::rnorm(length(arr), mean = bias, sd = sd)
}

# Simulation function
## Get the non-dsl estimator
get_estimator_func <- function(estimator_name) {
  # Return a fitting function `function(data, fml)` for the non-DSL estimator.
  #
  # Args:
  #   estimator_name: "lm" (ordinary least squares) or "logit"
  #                   (binomial glm).
  # Returns:
  #   A closure that fits the model `fml` on `data`.
  # Errors:
  #   Unknown names now fail here instead of silently returning NULL (which
  #   only surfaced later as "attempt to apply non-function").
  switch(
    estimator_name,
    lm = function(data, fml) {
      lm(fml, data = data)
    },
    logit = function(data, fml) {
      glm(fml, data = data, family = "binomial")
    },
    stop("Unknown estimator_name '", estimator_name,
         "'; expected \"lm\" or \"logit\"", call. = FALSE)
  )
}

# For replacing term in formula
replace_term_in_formula <- function(fml, old_term, new_term) {
  # Replace every whole-name occurrence of `old_term` in a formula by `new_term`.
  #
  # Args:
  #   fml:      a formula, or a string parseable as one.
  #   old_term: name to replace, matched literally (regex metacharacters have
  #             no special meaning) and only as a complete R name, so "x" does
  #             not match inside "x2", "x_lag" or "x.lag".
  #   new_term: replacement text.
  # Returns:
  #   A new formula (one- or two-sided, like the input) whose environment is
  #   that of `fml` (the caller's environment when `fml` is a string).
  # Errors:
  #   If `fml` is neither a formula nor a character string.
  if (is.character(fml)) {
    # Parse into `fml` itself: the previous version assigned to another
    # variable and then deparsed the raw string, quotes included.
    fml <- stats::as.formula(fml, env = parent.frame())
  } else if (!inherits(fml, "formula")) {
    stop("Input must be a formula or a string representation of a formula")
  }

  # Deparse one side to a single string (long expressions would otherwise
  # come back split over several lines).
  side_to_string <- function(expr) {
    paste(deparse(expr, width.cutoff = 500L), collapse = " ")
  }

  # \Q...\E quotes the term literally; the lookarounds require that it is not
  # preceded or followed by a character that can be part of an R name.
  pattern <- paste0("(?<![[:alnum:]._])\\Q", old_term, "\\E(?![[:alnum:]._])")
  swap <- function(text) gsub(pattern, new_term, text, perl = TRUE)

  # Work on the formula's parse tree rather than splitting its text on "~",
  # so one-sided formulas (`~ rhs`) are handled too.
  two_sided <- length(fml) == 3L
  rhs <- swap(side_to_string(fml[[if (two_sided) 3L else 2L]]))
  new_text <- if (two_sided) {
    paste(swap(side_to_string(fml[[2L]])), "~", rhs)
  } else {
    paste("~", rhs)
  }

  stats::as.formula(new_text, env = environment(fml))
}

single_simulation <- function(sim_data, fml, surrogate_var, n_label, estimator_name, q_error_func, gs_error_func, seed=634){
  # Run one replicate of the surrogate-only (SO) vs DSL comparison.
  #
  # Steps: bootstrap-resample `sim_data`; add a noisy surrogate column `Q`;
  # corrupt `surrogate_var` into a "silver standard"; keep it only for
  # `n_label` randomly chosen rows; then fit an SO model (with `Q` in place of
  # `surrogate_var`) and a DSL model.
  #
  # Args:
  #   sim_data:       data frame containing every variable in `fml`.
  #   fml:            model formula (or string) using `surrogate_var`.
  #   surrogate_var:  name of the column that is expensive to label.
  #   n_label:        number of rows whose `surrogate_var` is kept.
  #   estimator_name: "lm" or "logit"; used for the SO fit and as dsl's `model`.
  #   q_error_func, gs_error_func: functions of the whole data frame returning
  #                   one value per row (the surrogate `Q` and the corrupted
  #                   labels, respectively).
  #   seed:           RNG seed for this replicate.
  # Returns:
  #   A list with SO/DSL coefficients, standard errors, 95% confidence
  #   intervals (two-column matrices), and the accuracy of the silver standard
  #   (gs_acc) and of Q (q_acc) against the resampled original labels.
  set.seed(seed)

  # Bootstrap-resample rows.
  n_obs <- nrow(sim_data)
  sim_data <- sim_data[sample(seq_len(n_obs), replace = TRUE), , drop = FALSE]

  # Labels before any corruption, used to measure accuracy.
  ground_truth <- sim_data[[surrogate_var]]

  # Fitting function for the surrogate-only estimator.
  estimator_func <- get_estimator_func(estimator_name)

  # Draw surrogates.
  sim_data[, "Q"] <- q_error_func(sim_data)
  q_acc <- mean(sim_data$Q == ground_truth)

  # Introduce errors into the gold standard -> "silver standard".
  sim_data[, surrogate_var] <- gs_error_func(sim_data)
  gs_acc <- mean(sim_data[[surrogate_var]] == ground_truth)

  # Labelled rows are drawn uniformly without replacement, so each row's
  # probability of being labelled is n_label / n_obs. `ps` is handed to dsl()
  # as `sample_prob`, so it must be that probability; the previous 1 / n_obs
  # was wrong. sample() normalises `prob`, so the constant's value does not
  # change which rows are drawn.
  sim_data[, "ps"] <- rep(n_label / n_obs, n_obs)
  sim_data[, "R"] <- 0
  labelled <- sample(seq_len(n_obs), size = n_label, replace = FALSE, prob = sim_data$ps)
  sim_data[labelled, "R"] <- 1
  sim_data[sim_data$R == 0, surrogate_var] <- NA

  # Surrogate-only estimator: the same model with Q in place of surrogate_var.
  so_fml <- replace_term_in_formula(fml, surrogate_var, "Q")
  so_model <- estimator_func(sim_data, so_fml)

  # DSL estimator.
  dsl_model <- dsl(
    data = as.data.frame(sim_data),
    formula = fml,
    model = estimator_name,
    labeled = "R",
    sample_prob = "ps",
    prediction = "Q",
    predicted_var = surrogate_var,
    seed = seed
  )

  # Coefficients, standard errors and 95% confidence intervals.
  so_coef <- so_model$coefficients
  dsl_coef <- dsl_model$coefficients
  so_se <- coef(summary(so_model))[, "Std. Error"]
  dsl_se <- dsl_model$standard_errors
  so_confint <- confint(so_model)
  z <- stats::qnorm(0.975)
  dsl_confint <- cbind(dsl_coef - z * dsl_se, dsl_coef + z * dsl_se)

  list(
    so_coef = so_coef,
    dsl_coef = dsl_coef,
    so_se = so_se,
    dsl_se = dsl_se,
    so_confint = so_confint,
    dsl_confint = dsl_confint,
    gs_acc = gs_acc,
    q_acc = q_acc
  )
}

single_coverage <- function(coefs, confints){
  # Elementwise check that each coefficient lies in its closed interval
  # [confints[, 1], confints[, 2]]; returns a logical vector.
  lower <- confints[, 1]
  upper <- confints[, 2]
  above_lower <- coefs >= lower
  below_upper <- coefs <= upper
  above_lower & below_upper
}

extract_results_multiple_coefficients <- function(results, true_coef){
  # Summarise a list of single_simulation() results against the true values.
  #
  # Args:
  #   results:   list of per-simulation result lists (fields so_coef, dsl_coef,
  #              so_se, dsl_se, so_confint, dsl_confint, ...).
  #   true_coef: vector of true coefficients, in the same order as the fitted
  #              coefficients (it is recycled down each column below).
  # Returns:
  #   A list with bias (mean over coefficients of |mean error|), RMSE (mean
  #   over coefficients of per-coefficient RMSE) and coverage (overall and per
  #   coefficient) for SO and DSL, plus the coefficient x simulation matrices
  #   of estimates and standard errors, `true_coef`, and the raw `results`.

  # Stack one field into a (coefficient x simulation) matrix. do.call(cbind)
  # always returns a matrix, whereas sapply() collapses to a plain vector when
  # there is a single coefficient, which made rowMeans() fail.
  stack_field <- function(get_field) {
    do.call(cbind, lapply(results, get_field))
  }

  so_coefs <- stack_field(function(x) x$so_coef)
  dsl_coefs <- stack_field(function(x) x$dsl_coef)
  so_ses <- stack_field(function(x) x$so_se)
  dsl_ses <- stack_field(function(x) x$dsl_se)

  so_err <- so_coefs - true_coef
  dsl_err <- dsl_coefs - true_coef

  so_bias <- mean(abs(rowMeans(so_err)))
  dsl_bias <- mean(abs(rowMeans(dsl_err)))

  so_rmse <- mean(sqrt(rowMeans(so_err^2)))
  dsl_rmse <- mean(sqrt(rowMeans(dsl_err^2)))

  so_coverage_var <- rowMeans(stack_field(function(x) single_coverage(true_coef, x$so_confint)))
  dsl_coverage_var <- rowMeans(stack_field(function(x) single_coverage(true_coef, x$dsl_confint)))
  so_coverage <- mean(so_coverage_var)
  dsl_coverage <- mean(dsl_coverage_var)

  list(
    so_bias = so_bias,
    dsl_bias = dsl_bias,
    so_rmse = so_rmse,
    dsl_rmse = dsl_rmse,
    so_coverage = so_coverage,
    dsl_coverage = dsl_coverage,
    so_coverage_var = so_coverage_var,
    dsl_coverage_var = dsl_coverage_var,
    so_coefs = so_coefs,
    dsl_coefs = dsl_coefs,
    so_ses = so_ses,
    dsl_ses = dsl_ses,
    true_coef = true_coef,
    sim_results = results
  )
}


# Run simulation `n_sims` times, calculate statistics
repeat_simulation <- function(src_data, fml, surrogate_var, n_label, estimator_name, q_error_func, gs_error_func, n_sims, true_coef = NULL){
  # Run single_simulation() `n_sims` times (seeds 634 + 1, ..., 634 + n_sims)
  # and summarise the results with extract_results_multiple_coefficients().
  #
  # Args: as for single_simulation(), plus
  #   n_sims:    number of replicates.
  #   true_coef: true coefficient vector, in the order the models report them.
  #              Previously this was an undefined free variable, silently read
  #              from the global environment (or an error if absent). When NULL,
  #              it defaults to the coefficients of `estimator_name` fitted on
  #              the full `src_data`, i.e. the value for the population that the
  #              replicates bootstrap from; pass it explicitly to override.
  # Returns:
  #   The list produced by extract_results_multiple_coefficients().
  if (is.null(true_coef)) {
    oracle_fit <- get_estimator_func(estimator_name)(src_data, stats::as.formula(fml))
    true_coef <- stats::coef(oracle_fit)
  }

  results <- lapply(seq_len(n_sims), function(i) {
    single_simulation(src_data, fml, surrogate_var, n_label, estimator_name,
                      q_error_func, gs_error_func, seed = 634 + i)
  })

  extract_results_multiple_coefficients(results, true_coef)
}
